{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Theano Conditionals"
   ]
  },
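  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`theano` provides two conditional constructs that operate on symbolic variables, `ifelse` and `switch`:\n",
    "\n",
    "- `ifelse(condition, var1, var2)`\n",
    "    - `condition` is a scalar; returns `var1` if it is true, otherwise `var2`.\n",
    "- `switch(tensor, var1, var2)`\n",
    "    - An elementwise version of `ifelse`, and therefore more general: the condition is a tensor, and each output element is taken from `var1` or `var2` according to the matching element of the condition.\n",
    "- `switch` evaluates both outputs, whereas `ifelse` is lazy and evaluates only the output selected by the condition.\n",
    "\n",
    "`ifelse` is imported from `theano.ifelse`, while `switch` lives in the `theano.tensor` module."
   ]
  },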
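  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what *elementwise* means for `switch`, assume a condition tensor $\\mathbf c$ and two tensors $\\mathbf u$ and $\\mathbf v$ of the same shape (these symbols are introduced only for illustration). Each output element is then selected independently, with any nonzero condition value counting as true:\n",
    "\n",
    "$$\n",
    "\\mathrm{switch}(\\mathbf c, \\mathbf u, \\mathbf v)_{ij} = \\left\\{\n",
    "\\begin{aligned}\n",
    "    u_{ij} & ,\\ c_{ij} \\neq 0\\\\\n",
    "    v_{ij} & ,\\ c_{ij} = 0\n",
    "\\end{aligned}\n",
    "\\right.\n",
    "$$\n",
    "\n",
    "`ifelse`, by contrast, takes a single scalar condition and returns one of its two arguments whole."
   ]
  },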
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)\n"
     ]
    }
   ],
   "source": [
    "import theano, time\n",
    "import theano.tensor as T\n",
    "import numpy as np\n",
    "from theano.ifelse import ifelse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we have two scalar parameters, $a$ and $b$, and two matrices, $\\mathbf{x}$ and $\\mathbf{y}$, and define the function:\n",
    "\n",
    "$$ \n",
    "\\mathbf z = f(a, b,\\mathbf{x, y}) = \\left\\{ \n",
    "\\begin{aligned}\n",
    "    \\mathbf x & ,\\ a < b\\\\\n",
    "    \\mathbf y & ,\\ a \\ge b\n",
    "\\end{aligned}\n",
    "\\right.\n",
    "$$\n",
    "\n",
    "Define the variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "a, b = T.scalars('a', 'b')\n",
    "x, y = T.matrices('x', 'y')"
   ]
  },
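  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the graph with `ifelse`. `T.lt(a, b)` is the strict comparison `a < b` and `T.gt(a, b)` is `a > b`; the non-strict versions are `T.le()` and `T.ge()`:"
   ]
  },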
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Lazy conditional: returns x when a < b, otherwise y\n",
    "z_ifelse = ifelse(T.lt(a, b), x, y)\n",
    "\n",
    "f_ifelse = theano.function([a, b, x, y], z_ifelse)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the same graph with `switch`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Elementwise conditional; the scalar condition is broadcast over x and y\n",
    "z_switch = T.switch(T.lt(a, b), x, y)\n",
    "\n",
    "f_switch = theano.function([a, b, x, y], z_switch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "val1 = 0.\n",
    "val2 = 1.\n",
    "big_mat1 = np.ones((10000, 1000), dtype=theano.config.floatX)\n",
    "big_mat2 = np.ones((10000, 1000), dtype=theano.config.floatX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the running time of the two functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time spent evaluating both values 0.638598 sec\n",
      "time spent evaluating one value 0.461249 sec\n"
     ]
    }
   ],
   "source": [
    "n_times = 10\n",
    "\n",
    "# switch: both outputs are evaluated on every call\n",
    "tic = time.time()\n",
    "for i in range(n_times):\n",
    "    f_switch(val1, val2, big_mat1, big_mat2)\n",
    "print('time spent evaluating both values %f sec' % (time.time() - tic))\n",
    "\n",
    "# ifelse: only the selected output is evaluated\n",
    "tic = time.time()\n",
    "for i in range(n_times):\n",
    "    f_ifelse(val1, val2, big_mat1, big_mat2)\n",
    "print('time spent evaluating one value %f sec' % (time.time() - tic))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
